Loss function: Earth movel distance - The effort needed to make both distributions equal

  • Critics(discriminator) values not restricted to be between 0 and 1
  • Even for very different distributions, gradients are significant and high enough to drive the process in the right way
In [ ]:
import torch
import torchvision
import os
import PIL
import pdb
import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
from PIL import Image
In [ ]:
# OPTIONAL
!pip install wandb -qqq
import wandb
wandb.login(key='9e59c0e5f929ee5a223fead436cba53be4f90e1a')
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.9/6.9 MB 29.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 kB 27.7 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 289.6/289.6 kB 31.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.7/62.7 kB 9.6 MB/s eta 0:00:00
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin
wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly.
wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Out[ ]:
True
In [ ]:
def show(tensor, num=25, wandbactivation=0, name=''):
  data = tensor.detach().cpu()
  grid = make_grid(data[:num], nrow=5).permute(1,2,0)

  # optional
  # wandb - online activation
  if (wandbactivation==1 and wandbact==1):
      wandb.log({name:wandb.Image(grid.numpy().clip(0,1))})


  # cliping pixels to range(0,1)
  plt.imshow(grid.clip(0,1))
  plt.show()
In [ ]:
## hyperparameters and general parameters

n_epochs = 1000
batch_size = 128
lr = 1e-4

# z_dim - input noise latent vector dim
z_dim = 200
device = 'cuda'

cur_step = 0

# 5 cycles training of the critic, then 1 of the generator
# generally critic needs more training than the generator
crit_cycles = 5

gen_losses = []
crit_losses = []
show_step  = 35
save_step = 35

# optional, tracking stats online
wandbact = 1
In [ ]:
# optional wandb
%%capture
# experiment_name = wandb.util.generate_id()
experiment_name = 'MY EXP'
myrun = wandb.init(
    project='wgan',
    name=experiment_name,
    group=experiment_name,
    config={'optimizer':'adam',
            'model':'wgan gp',
            'epoch':'1000',
            'batch_size':128
            }
)
config = wandb.config
In [ ]:
# optional wandb
print(experiment_name)
In [ ]:
# generator model

class Generator(nn.Module):

  # d_dim - internal dimension for the output of the convolutional layers
  def __init__(self, z_dim=64, d_dim=16):
    super(Generator, self).__init__()
    self.z_dim = z_dim

    self.gen = nn.Sequential(

        # ConvTranspose2d: in_channels, out_channels, kernel_size, stride=1, padding=0
        # n - width or height
        # (n - 1) * stride - 2 * padding + kernel_size

        # generator starts 1x1 pixel and z_dim number of channels and it gives it dimensionality of latent space
        # starting with 200 channels bringing up channels to 512 and increasing size of the image

        # 1st block
        nn.ConvTranspose2d(in_channels=z_dim, out_channels=d_dim*32, kernel_size=4, stride=1, padding=0),
        # normalizing values for improving stability, num_features = out from conv layer
        nn.BatchNorm2d(num_features=d_dim*32),
        # applying nonlinearity
        nn.ReLU(True),


        # 2nd block
        nn.ConvTranspose2d(in_channels=d_dim*32, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*16),
        nn.ReLU(inplace=True),

        # 3rd block
        nn.ConvTranspose2d(in_channels=d_dim*16, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*8),
        nn.ReLU(inplace=True),

        # 4th block
        nn.ConvTranspose2d(in_channels=d_dim*8, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*4),
        nn.ReLU(inplace=True),

        # 5th block
        nn.ConvTranspose2d(in_channels=d_dim*4, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*2),
        nn.ReLU(inplace=True),

        # 6th block
        # in output layer we stay with ONLY 3 channels
        nn.ConvTranspose2d(in_channels=d_dim*2, out_channels=3, kernel_size=4, stride=2, padding=1),
        # no batch norm

        # results from -1 to 1
        nn.Tanh()

    )


  def forward(self, noise):
    # generator recives noise as input
    x = noise.view(len(noise), self.z_dim, 1, 1) # 128(batch) x 200(z_dim) x 1(height) x 1(width)

    return self.gen(x)


def gen_noise(num, z_dim, device='cuda'):
  return torch.randn(num, z_dim, device=device) # 128 x 200


# n - width or height
# nn.Conv2d: (n + 2 * pad - ks) // stride + 1
# nn.ConvTranspose2d: (n - 1) * stride - 2 * padding + kernel_size

(n - 1) * stride - 2 * padding + kernel_size

  • 1st step:
  • (1 - 1) * 1 - 2 * 0 + 4 = 4x4 image, channels: 200 to 512
  • 2nd step
  • (4 - 1) * 2 - 2 * 1 + 4 = 8x8 image, channels: 512 to 256
  • 3rd step
  • (8 - 1) * 2 - 2 * 1 + 4 = 16x16 image, channels: 256 to 128
  • 4th step
  • (16 - 1) * 2 - 2 * 1 + 4 = 32x32 image, channels: 128 to 64
  • 5th step
  • (32 - 1) * 2 - 2 * 1 + 4 = 64x64 image, channels: 64 to 32
  • 6th step
  • (64 - 1) * 2 - 2 * 1 + 4 = 128x128 image, channels: 32 to 3
In [ ]:
# critic model
# Conv2d: in_channels, out_channels, kernel_size, stride=1, padding=0
# (n + 2 * padding - kernel_size) // stride + 1

class Critic(nn.Module):
  def __init__(self, d_dim=16):
    super(Critic, self).__init__()

    self.crit = nn.Sequential(
        # 1st block
        nn.Conv2d(in_channels=3, out_channels=d_dim, kernel_size=4, stride=2, padding=1),

        # instead of batchnorm2d, normalizing according to the values of the whole instance insted of values of the batch
        nn.InstanceNorm2d(d_dim), # works the best

        # leaky keeps information, negative values have little slope(small negative numbers), but theyre not converted to 0
        nn.LeakyReLU(0.2),

        # 2nd block
        nn.Conv2d(in_channels=d_dim, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*2),
        nn.LeakyReLU(0.2),

        # 3rd block
        nn.Conv2d(in_channels=d_dim*2, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*4),
        nn.LeakyReLU(0.2),

        # 4th block
        nn.Conv2d(in_channels=d_dim*4, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*8),
        nn.LeakyReLU(0.2),

        # 5th block
        nn.Conv2d(in_channels=d_dim*8, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*16),
        nn.LeakyReLU(0.2),

        # 6th block
        # we return 1 value, either its fake or real
        # stride has to be 1 and padding 0
        nn.Conv2d(in_channels=d_dim*16, out_channels=1, kernel_size=4, stride=1, padding=0),

    )


  def forward(self, image):
    # image: 128(batch) x 3(channels) x 128(height) x 128(width)
    crit_pred = self.crit(image) # 128 x 1 x 1 x 1
    return crit_pred.view(len(crit_pred), -1) # 128(batch values) x 1(fake or real)

(n + 2 * padding - kernel_size) // stride + 1

  • 1st step
  • (128 + 2 * 1 - 4) //2 + 1 = 64x64 image, channels: 3 to 16
  • 2nd step
  • (64 + 2 * 1 - 4) //2 + 1 = 32x32 image, channels: 16 to 32
  • 3rd step
  • (32 + 2 * 1 - 4) //2 + 1 = 16x16 image, channels: 32 to 64
  • 4th step
  • (16 + 2 * 1 - 4) //2 + 1 = 8x8 image, channels: 64 to 128
  • 5th step
  • (8 + 2 * 1 - 4) //2 + 1 = 4x4 image, channels: 128 to 256
  • 6th step
  • (4 + 2 * 0 - 4) //1 + 1 = 1x1 image, channels: 256 to 1

Alternative way to initialize parameters

In [ ]:
def init_weights(m):
  if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
    torch.nn.init.normal(m.weight, 0.0, 0.02)
    torch.nn.init.constant(m.bias, 0)

  if isinstance(m, nn.BatchNorm2d):
    torch.nn.init.normal(m.weight, 0.0, 0.02)
    torch.nn.init.constant(m.bias, 0)

# Initializations - NOT HERE
# gen = gen.apply(init_weights)
# crit = crit.apply(init_weights)
In [ ]:
# loading dataset
import gdown, zipfile

# url = ''
path = '/content/drive/MyDrive'
download_path = f'{path}/img_align_celeba.zip'

if not os.path.exists(path):
  os.makedirs(path)

# gdown.download(url, download_path, quiet=False)

with zipfile.ZipFile(download_path, 'r') as ziphandler:
  ziphandler.extractall('.')
In [ ]:
!wget https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
!unzip -q img_align_celeba.zip
--2024-06-17 10:52:45--  https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
Resolving drive.google.com (drive.google.com)... 108.177.121.138, 108.177.121.100, 108.177.121.101, ...
Connecting to drive.google.com (drive.google.com)|108.177.121.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’


          view?reso     [<=>                 ]       0  --.-KB/s               
view?resourcekey=0-     [ <=>                ]  88.03K  --.-KB/s    in 0.003s  

2024-06-17 10:52:45 (26.0 MB/s) - ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’ saved [90138]

unzip:  cannot find or open img_align_celeba.zip, img_align_celeba.zip.zip or img_align_celeba.zip.ZIP.
In [ ]:
# class Dataset(Dataset):
class Dataset(Dataset):
  def __init__(self, path, size=128, lim=10000):
    self.sizes = [size, size]

    # paths to the images
    items, labels = [], []

    for data in os.listdir(path)[:lim]:
      # path: './data/celeba/img_align_celeba'
      # data: '123213.img'
      item = os.path.join(path, data)
      items.append(item)
      labels.append(data)
    self.items=items
    self.labels=labels

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    # open image idx
    data = PIL.Image.open(self.items[idx]).convert('RGB') # size of the image: fe. 1278, 121
    data = np.asarray(torchvision.transforms.Resize(self.sizes)(data)) # 128 x 128 x 3
    data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False) # 3 x 128 x 128
    # from np to tensor for training, div = standarizing
    data = torch.from_numpy(data).div(255) # leaving values from 0 to 1
    return data, self.labels[idx]
In [ ]:
len(os.listdir('./img_align_celeba'))
Out[ ]:
202599
In [ ]:
# Dataset
data_path = './img_align_celeba'
ds = Dataset(data_path, size=128, lim=50000)

# DataLoader
dataloader = DataLoader(dataset=ds, batch_size=128, shuffle=True)

# Models
gen = Generator(z_dim=z_dim).to(device)
crit = Critic().to(device)

# Optimizers
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9)) # betas - internal calculations, works well with this architecture
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5, 0.9))

# Initializations - OPTIONAL
# gen = gen.apply(init_weights)
# crit = crit.apply(init_weights)

# wandb - OPTIONAL
if (wandbact==1):
  wandb.watch(gen, log_freq=100)
  wandb.watch(crit, log_freq=100)

x, y = next(iter(dataloader))
show(x)
No description has been provided for this image
In [ ]:
# gradient penalty calculation

def get_gp(real, fake, crit, alpha, gamma=10): # alpha does random interpolations, gamma stands for intensity of gp regularizations
  mix_images = real * alpha + fake * (1-alpha) # 128(batch) x 3 x 128 x 128, linear interpolation
  mix_scores = crit(mix_images) # predictions: 128(batch) x 1

  # we want to penalize gradients that are to large
  # computing and returning the sum of the gradients of the outputs with respect to the inputs
  gradient = torch.autograd.grad(
      inputs = mix_images,
      outputs = mix_scores,

      # puting ones to take into account all the grades and outputs
      grad_outputs = torch.ones_like(mix_scores),
      retain_graph=True,
      create_graph=True,

  )[0] # return first batch 128(bs) x 3 x 128 x 128

  gradient = gradient.view(len(gradient), -1) # 128 x 49512(128x128x3)
  gradient_norm = gradient.norm(2, dim=1) # L2 norm
  gp = gamma * ((gradient_norm - 1)**2).mean()

  return gp
In [ ]:
# saving and loading checkpoints

if not os.path.exists('./content/drive/MyDrive/training_data/'):
    os.makedirs('./content/drive/MyDrive/training_data/')
root_path='./content/drive/MyDrive/training_data/'

def save_checkpoint(name):
  torch.save({
      'epoch':epoch,
      'model_state_dict':gen.state_dict(),
      'optimizer_state_dict':gen_opt.state_dict(),
  }, f'{root_path}G-{name}.pkl')

  torch.save({
      'epoch':epoch,
      'model_state_dict':crit.state_dict(),
      'optimizer_state_dict':crit_opt.state_dict(),
  }, f'{root_path}C-{name}.pkl')

  print('Saved checkpoint')


def load_checkpoint(name):
  # generator
  # loading file
  checkpoint = torch.load(f'{root_path}G-{name}.pkl')
  # loading values to the model
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])

  # critic
  checkpoint = torch.load(f'{root_path}C-{name}.pkl')
  crit.load_state_dict(checkpoint['model_state_dict'])
  crit_opt.load_state_dict(checkpoint['optimizer_state_dict'])

  print('Loaded checkpoint')
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-1-39c5fc667a2d> in <cell line: 3>()
      1 # saving and loading checkpoints
      2 
----> 3 if not os.path.exists('./content/drive/MyDrive/training_data/'):
      4     os.makedirs('./content/drive/MyDrive/training_data/')
      5 root_path='./content/drive/MyDrive/training_data/'

NameError: name 'os' is not defined
In [ ]:
#!cp C-final* ./data/
#!cp G-final* ./data/
epoch=1
save_checkpoint('test')
load_checkpoint('test')
Saved checkpoint
Loaded checkpoint
In [ ]:
# Training loop

for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_bs = len(real) # 128
    real = real.to(device)


    ## Critic
    mean_crit_loss = 0
    for _ in range(crit_cycles):

      # zeroing gradient of the optimizer
      crit_opt.zero_grad()

      noise = gen_noise(cur_bs, z_dim)
      fake = gen(noise)
      # detaching for not affecting the parameters of the generator
      crit_fake_pred = crit(fake.detach())
      crit_real_pred = crit(real)

      # alpha vector (numbers size of the batch)
      alpha = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True) # 128 x 1 x 1 x 1

      # calculating gradient penalty
      gp = get_gp(real, fake.detach(), crit, alpha)

      # calculating loss
      crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp

      #.item - taking only the number from the tensor
      mean_crit_loss += crit_loss.item() / crit_cycles

      # optimizer backpropagation
      crit_loss.backward(retain_graph=True)
      crit_opt.step()

    # list of losses values
    crit_losses += [mean_crit_loss]


    ## Generator

    # zeroing gradient of the optimizer
    gen_opt.zero_grad()

    # creating noise 128 x 200
    noise = gen_noise(cur_bs, z_dim)

    # passing noise through generator
    fake = gen(noise)

    # passing them through critic
    crit_fake_pred = crit(fake)

    # negative of the pred of the critic
    gen_loss = -crit_fake_pred.mean()

    # backpropagation
    gen_loss.backward()

    # updating the parameters of the generator
    gen_opt.step()

    gen_losses+=[gen_loss.item()]

    ## Statistics

    if (wandb==1):
      wandb.log(
          {'Epoch':epoch,
           'Step':cur_step,
           'Critic loss':mean_crit_loss,
           'Gen loss':gen_loss,
           }
      )

    if cur_step % save_step == 0 and cur_step > 0:
      print('Saving checkpoint:', cur_step, save_step)
      # best to save the files with the different names fe. nr of epoch
      save_checkpoint('latest')

    if (cur_step % show_step == 0 and cur_step > 0):
      show(fake, wandbactivation=1, name='fake')
      show(real, wandbactivation=1, name='real')

      gen_mean = sum(gen_losses[-show_step:]) / show_step
      crit_mean = sum(crit_losses[-show_step:]) / show_step
      print(f'Epoch: {epoch}, step: {cur_step}, Generator loss: {gen_loss}, Critic loss: {crit_loss}')

      plt.plot(range(len(gen_losses)),
               torch.Tensor(gen_losses),
               label='Generator loss')

      plt.plot(range(len(crit_losses)),
               torch.Tensor(crit_losses),
               label='Critic loss')

      plt.ylim(-200,200)
      plt.legend()
      plt.show()
    cur_step += 1
  0%|          | 0/391 [00:00<?, ?it/s]
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Saving checkpoint: 35 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 35, Generator loss: 15.25752067565918, Critic loss: -21.56631088256836
No description has been provided for this image
Saving checkpoint: 70 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 70, Generator loss: 25.56017303466797, Critic loss: -29.439315795898438
No description has been provided for this image
Saving checkpoint: 105 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 105, Generator loss: 24.86949920654297, Critic loss: -23.476028442382812
No description has been provided for this image
Saving checkpoint: 140 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 140, Generator loss: 26.923324584960938, Critic loss: -22.520244598388672
No description has been provided for this image
Saving checkpoint: 175 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 175, Generator loss: 25.521202087402344, Critic loss: -18.931087493896484
No description has been provided for this image
Saving checkpoint: 210 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 210, Generator loss: 20.612937927246094, Critic loss: -19.237375259399414
No description has been provided for this image
Saving checkpoint: 245 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 245, Generator loss: 15.082696914672852, Critic loss: -14.787302017211914
No description has been provided for this image
Saving checkpoint: 280 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 280, Generator loss: 20.040992736816406, Critic loss: -13.67352294921875
No description has been provided for this image
Saving checkpoint: 315 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 315, Generator loss: 16.89620018005371, Critic loss: -14.186975479125977
No description has been provided for this image
Saving checkpoint: 350 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 350, Generator loss: 19.051918029785156, Critic loss: -12.683650016784668
No description has been provided for this image
Saving checkpoint: 385 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 0, step: 385, Generator loss: 16.979228973388672, Critic loss: -14.533666610717773
No description has been provided for this image
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|          | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 420 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 420, Generator loss: 17.686763763427734, Critic loss: -14.1609525680542
No description has been provided for this image
Saving checkpoint: 455 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 455, Generator loss: 18.91667938232422, Critic loss: -12.47118091583252
No description has been provided for this image
Saving checkpoint: 490 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 490, Generator loss: 18.239500045776367, Critic loss: -10.373126029968262
No description has been provided for this image
Saving checkpoint: 525 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 525, Generator loss: 16.01999282836914, Critic loss: -11.49423885345459
No description has been provided for this image
Saving checkpoint: 560 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 560, Generator loss: 15.156608581542969, Critic loss: -12.317692756652832
No description has been provided for this image
Saving checkpoint: 595 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 595, Generator loss: 16.044208526611328, Critic loss: -10.365501403808594
No description has been provided for this image
Saving checkpoint: 630 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 630, Generator loss: 18.690975189208984, Critic loss: -10.940778732299805
No description has been provided for this image
Saving checkpoint: 665 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 665, Generator loss: 17.979543685913086, Critic loss: -10.223814010620117
No description has been provided for this image
Saving checkpoint: 700 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 700, Generator loss: 14.435104370117188, Critic loss: -13.017997741699219
No description has been provided for this image
Saving checkpoint: 735 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 735, Generator loss: 19.215503692626953, Critic loss: -8.568474769592285
No description has been provided for this image
Saving checkpoint: 770 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 1, step: 770, Generator loss: 22.676485061645508, Critic loss: -13.234884262084961
No description has been provided for this image
  0%|          | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 805 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 805, Generator loss: 20.212373733520508, Critic loss: -9.45648193359375
No description has been provided for this image
Saving checkpoint: 840 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 840, Generator loss: 19.213638305664062, Critic loss: -9.419897079467773
No description has been provided for this image
Saving checkpoint: 875 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 875, Generator loss: 17.558929443359375, Critic loss: -9.927236557006836
No description has been provided for this image
Saving checkpoint: 910 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 910, Generator loss: 23.620121002197266, Critic loss: -10.862022399902344
No description has been provided for this image
Saving checkpoint: 945 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 945, Generator loss: 18.534482955932617, Critic loss: -9.245712280273438
No description has been provided for this image
Saving checkpoint: 980 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 980, Generator loss: 23.425382614135742, Critic loss: -10.018635749816895
No description has been provided for this image
Saving checkpoint: 1015 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 1015, Generator loss: 20.67441177368164, Critic loss: -9.165563583374023
No description has been provided for this image
Saving checkpoint: 1050 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 1050, Generator loss: 19.643081665039062, Critic loss: -8.507790565490723
No description has been provided for this image
Saving checkpoint: 1085 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 1085, Generator loss: 20.506515502929688, Critic loss: -10.645575523376465
No description has been provided for this image
Saving checkpoint: 1120 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 1120, Generator loss: 24.903501510620117, Critic loss: -8.776479721069336
No description has been provided for this image
Saving checkpoint: 1155 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 2, step: 1155, Generator loss: 25.28638458251953, Critic loss: -10.320352554321289
No description has been provided for this image
  0%|          | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 1190 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1190, Generator loss: 16.862964630126953, Critic loss: -9.55805778503418
No description has been provided for this image
Saving checkpoint: 1225 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1225, Generator loss: 16.591989517211914, Critic loss: -8.997180938720703
No description has been provided for this image
Saving checkpoint: 1260 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1260, Generator loss: 20.295799255371094, Critic loss: -8.942169189453125
No description has been provided for this image
Saving checkpoint: 1295 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1295, Generator loss: 16.82471466064453, Critic loss: -8.582033157348633
No description has been provided for this image
Saving checkpoint: 1330 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1330, Generator loss: 19.550825119018555, Critic loss: -8.895916938781738
No description has been provided for this image
Saving checkpoint: 1365 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1365, Generator loss: 19.583209991455078, Critic loss: -8.709506034851074
No description has been provided for this image
Saving checkpoint: 1400 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1400, Generator loss: 18.479915618896484, Critic loss: -9.812521934509277
No description has been provided for this image
Saving checkpoint: 1435 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1435, Generator loss: 18.085033416748047, Critic loss: -8.537590026855469
No description has been provided for this image
Saving checkpoint: 1470 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1470, Generator loss: 15.994598388671875, Critic loss: -8.235672950744629
No description has been provided for this image
Saving checkpoint: 1505 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1505, Generator loss: 17.538623809814453, Critic loss: -9.286766052246094
No description has been provided for this image
Saving checkpoint: 1540 35
Saved checkpoint
No description has been provided for this image
No description has been provided for this image
Epoch: 3, step: 1540, Generator loss: 14.294757843017578, Critic loss: -10.109136581420898
No description has been provided for this image
  0%|          | 0/391 [00:00<?, ?it/s]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-d641a5009044> in <cell line: 3>()
      2 
      3 for epoch in range(n_epochs):
----> 4   for real, _ in tqdm(dataloader):
      5     cur_bs = len(real) # 128
      6     real = real.to(device)

/usr/local/lib/python3.10/dist-packages/tqdm/notebook.py in __iter__(self)
    248         try:
    249             it = super().__iter__()
--> 250             for obj in it:
    251                 # return super(tqdm...) will not catch exception
    252                 yield obj

/usr/local/lib/python3.10/dist-packages/tqdm/std.py in __iter__(self)
   1179 
   1180         try:
-> 1181             for obj in iterable:
   1182                 yield obj
   1183                 # Update and possibly print the progressbar.

/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    629                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630                 self._reset()  # type: ignore[call-arg]
--> 631             data = self._next_data()
    632             self._num_yielded += 1
    633             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    673     def _next_data(self):
    674         index = self._next_index()  # may raise StopIteration
--> 675         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676         if self._pin_memory:
    677             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     49                 data = self.dataset.__getitems__(possibly_batched_index)
     50             else:
---> 51                 data = [self.dataset[idx] for idx in possibly_batched_index]
     52         else:
     53             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     49                 data = self.dataset.__getitems__(possibly_batched_index)
     50             else:
---> 51                 data = [self.dataset[idx] for idx in possibly_batched_index]
     52         else:
     53             data = self.dataset[possibly_batched_index]

<ipython-input-12-d25a227d10ee> in __getitem__(self, idx)
     22     # open image idx
     23     data = PIL.Image.open(self.items[idx]).convert('RGB') # size of the image: fe. 1278, 121
---> 24     data = np.asarray(torchvision.transforms.Resize(self.sizes)(data)) # 128 x 128 x 3
     25     data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False) # 3 x 128 x 128
     26     # from np to tensor for training, div = standarizing

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/torchvision/transforms/transforms.py in forward(self, img)
    352             PIL Image or Tensor: Rescaled image.
    353         """
--> 354         return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
    355 
    356     def __repr__(self) -> str:

/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional.py in resize(img, size, interpolation, max_size, antialias)
    466             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    467         pil_interpolation = pil_modes_mapping[interpolation]
--> 468         return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
    469 
    470     return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)

/usr/local/lib/python3.10/dist-packages/torchvision/transforms/_functional_pil.py in resize(img, size, interpolation)
    248         raise TypeError(f"Got inappropriate size arg: {size}")
    249 
--> 250     return img.resize(tuple(size[::-1]), interpolation)
    251 
    252 

/usr/local/lib/python3.10/dist-packages/PIL/Image.py in resize(self, size, resample, box, reducing_gap)
   2190                 )
   2191 
-> 2192         return self._new(self.im.resize(size, resample, box))
   2193 
   2194     def reduce(self, factor, box=None):

KeyboardInterrupt: 

10000 / 128 = 78.125 - 79 steps per epoch

In [ ]:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()